# -*- coding: UTF-8 -*-

import os

import torch
from torch.utils.data import DataLoader

from torchvision import datasets, transforms

from pytorch_lightning.core import LightningModule


class ImageNet224Module(LightningModule):
    """LightningModule base that provides ImageFolder-backed data loaders.

    The training and validation sets are read from the ``train`` and ``val``
    sub-directories of ``data_path`` through ``datasets.ImageFolder``.

    NOTE(review): training crops are ``self.size`` (160 by default) while
    validation always uses a 256 resize followed by a 224 center crop. This
    looks like a deliberate train/eval resolution split, but confirm it is
    intended given the "224" in the class name.
    """

    def __init__(self, workers, batch_size, data_path, *, size=160):
        """Store the data-loading configuration.

        Args:
            workers: Forwarded as ``num_workers`` to every ``DataLoader``.
            batch_size: Forwarded as ``batch_size`` to every ``DataLoader``.
            data_path: Directory that contains the ``train`` and ``val``
                sub-directories.
            size: Output size given to ``RandomResizedCrop`` for training
                images. Defaults to 160, the previously hard-coded value.
        """
        super().__init__()
        self.workers = workers
        self.batch_size = batch_size
        self.data_path = data_path
        self.size = size

    @staticmethod
    def _normalize():
        """Return the channel normalization shared by all data loaders."""
        # Presumably the usual ImageNet per-channel statistics; kept in one
        # place so the train and val pipelines cannot drift apart.
        return transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

    def train_dataloader(self):
        """Build the shuffled training loader with random augmentation.

        Returns:
            A ``DataLoader`` over ``<data_path>/train`` that applies a random
            resized crop to ``self.size``, a random horizontal flip, tensor
            conversion and normalization. Incomplete final batches are
            dropped.
        """
        train_dir = os.path.join(self.data_path, 'train')
        train_dataset = datasets.ImageFolder(
            train_dir,
            transforms.Compose([
                transforms.RandomResizedCrop(self.size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                self._normalize(),
            ])
        )

        return DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self.workers,
        )

    def val_dataloader(self):
        """Build the deterministic validation loader.

        Returns:
            A ``DataLoader`` over ``<data_path>/val`` that resizes to 256,
            center-crops to 224, converts to a tensor and normalizes. The
            data is not shuffled and the final partial batch is kept.
        """
        val_dir = os.path.join(self.data_path, 'val')
        val_dataset = datasets.ImageFolder(
            val_dir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                self._normalize(),
            ])
        )

        return DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.workers,
        )

    def test_dataloader(self):
        """Return a test loader; the validation split doubles as the test set."""
        return self.val_dataloader()
